{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torch.utils import data\n",
    "import torch.optim as optim\n",
    "from torch import nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNNNet(nn.Module):\n",
    "    \"\"\"AlexNet-style convolutional network for image classification.\n",
    "\n",
    "    The forward pass chains three stages: a convolutional feature extractor,\n",
    "    an adaptive average pool that always produces 6x6 feature maps whatever\n",
    "    the incoming spatial size, and a fully connected head whose last layer\n",
    "    emits one raw score per class (no softmax is applied here).\n",
    "\n",
    "    Args:\n",
    "        num_classes: width of the final linear layer, i.e. number of scores\n",
    "            returned per image.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_classes=2):\n",
    "        super().__init__()\n",
    "\n",
    "        # Conv2d positional arguments are (in_channels, out_channels); when\n",
    "        # stride/padding are not given they default to 1 and 0 respectively.\n",
    "        conv_layers = [\n",
    "            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "            nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "        ]\n",
    "        # Sequential chains the layers in list order.\n",
    "        self.features = nn.Sequential(*conv_layers)\n",
    "\n",
    "        # Fixes the spatial size at 6x6 so the first Linear layer below can\n",
    "        # rely on 256 * 6 * 6 input features.\n",
    "        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))\n",
    "\n",
    "        # Dropout() uses its default rate of 0.5. It is only active in\n",
    "        # training mode, so callers must switch with model.train() and\n",
    "        # model.eval(); otherwise dropout would also be applied when testing.\n",
    "        head_layers = [\n",
    "            nn.Dropout(),\n",
    "            nn.Linear(256 * 6 * 6, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(4096, num_classes),\n",
    "        ]\n",
    "        self.classifier = nn.Sequential(*head_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the per-class scores for a batch of 3-channel images.\"\"\"\n",
    "        pooled = self.avgpool(self.features(x))\n",
    "        # Keep dimension 0 (the batch) and flatten everything after it.\n",
    "        return self.classifier(torch.flatten(pooled, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Image preprocessing pipeline applied to every sample.\n",
    "#\n",
    "# The module is addressed as `torchvision.transforms` on purpose: this cell\n",
    "# rebinds the bare name `transforms` to the pipeline object (the name the\n",
    "# dataset cell passes as `transform=`), so reading `transforms.Compose` here\n",
    "# would raise AttributeError as soon as the cell is executed a second time.\n",
    "transforms = torchvision.transforms.Compose([\n",
    "    # Resize every image to 64x64.\n",
    "    torchvision.transforms.Resize((64, 64)),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    # Per-channel normalisation; these values look like the commonly used\n",
    "    # ImageNet statistics - TODO confirm they suit this dataset.\n",
    "    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of each split (relative to the notebook's working directory).\n",
    "train_data_path = \"./train/\"\n",
    "val_data_path = \"./val/\"\n",
    "test_data_path = \"./test/\"\n",
    "\n",
    "# All three splits share the same preprocessing pipeline.\n",
    "# NOTE(review): presumably each root holds one sub-folder per class - verify\n",
    "# against the data on disk.\n",
    "train_data = torchvision.datasets.ImageFolder(\n",
    "    root=train_data_path,\n",
    "    transform=transforms,\n",
    ")\n",
    "val_data = torchvision.datasets.ImageFolder(\n",
    "    root=val_data_path,\n",
    "    transform=transforms,\n",
    ")\n",
    "test_data = torchvision.datasets.ImageFolder(\n",
    "    root=test_data_path,\n",
    "    transform=transforms,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Only the training split is shuffled; validation and test batches are\n",
    "# served in dataset order.\n",
    "train_data_loader = data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)\n",
    "val_data_loader = data.DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)\n",
    "test_data_loader = data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two output scores - the constructor's default, spelled out for the reader.\n",
    "alexnet = CNNNet(num_classes=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-entropy is computed directly on the raw scores the network returns.\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "# Adam over every parameter of the model.\n",
    "optimizer = optim.Adam(params=alexnet.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
    "    \"\"\"Fit `model` and print training/validation metrics once per epoch.\n",
    "\n",
    "    Every epoch performs one optimisation pass over `train_loader` followed by\n",
    "    an evaluation pass over `val_loader`. The function returns nothing; its\n",
    "    only output is one printed summary line per epoch.\n",
    "\n",
    "    Args:\n",
    "        model: module mapping a batch of inputs to per-class scores.\n",
    "        optimizer: optimizer whose `zero_grad()` / `step()` are called once per\n",
    "            training batch; it should already hold `model`'s parameters.\n",
    "        loss_fn: callable taking `(output, targets)` and returning a scalar\n",
    "            loss tensor. The epoch averages assume it is the batch mean.\n",
    "        train_loader: iterable of `(inputs, targets)` batches used for training.\n",
    "        val_loader: iterable of `(inputs, targets)` batches used for validation.\n",
    "        epochs: number of passes over `train_loader`.\n",
    "        device: device the model and every batch are moved to.\n",
    "    \"\"\"\n",
    "    # Batches are moved to `device` below, so the model has to live there as\n",
    "    # well; otherwise any non-CPU device fails on the first forward pass.\n",
    "    model.to(device)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        # ---- optimisation pass ----\n",
    "        model.train()   # training mode (dropout active)\n",
    "        training_loss = 0.0\n",
    "        train_examples = 0\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            loss = loss_fn(model(inputs), targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Weight the batch loss by the batch size so that a smaller final\n",
    "            # batch does not skew the epoch average.\n",
    "            training_loss += loss.item() * inputs.size(0)\n",
    "            train_examples += inputs.size(0)\n",
    "\n",
    "        # ---- evaluation pass ----\n",
    "        model.eval()   # evaluation mode (dropout disabled)\n",
    "        valid_loss = 0.0\n",
    "        num_correct = 0\n",
    "        num_examples = 0\n",
    "        # No gradients are needed for validation; disabling autograd avoids\n",
    "        # building a graph for every validation batch.\n",
    "        with torch.no_grad():\n",
    "            for inputs, targets in val_loader:\n",
    "                inputs = inputs.to(device)\n",
    "                targets = targets.to(device)\n",
    "                output = model(inputs)\n",
    "                valid_loss += loss_fn(output, targets).item() * inputs.size(0)\n",
    "                # Softmax is monotonic, so the arg-max of the raw scores is\n",
    "                # already the predicted class.\n",
    "                num_correct += (output.argmax(dim=1) == targets).sum().item()\n",
    "                num_examples += targets.size(0)\n",
    "\n",
    "        # Average over the examples actually seen; an empty loader yields NaN\n",
    "        # instead of raising ZeroDivisionError.\n",
    "        training_loss = training_loss / train_examples if train_examples else float('nan')\n",
    "        valid_loss = valid_loss / num_examples if num_examples else float('nan')\n",
    "        accuracy = num_correct / num_examples if num_examples else float('nan')\n",
    "        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss, valid_loss, accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0, Training Loss: 1.04, Validation Loss: 0.53, accuracy = 0.80\n",
      "Epoch: 1, Training Loss: 0.68, Validation Loss: 0.50, accuracy = 0.80\n",
      "Epoch: 2, Training Loss: 0.62, Validation Loss: 0.57, accuracy = 0.58\n",
      "Epoch: 3, Training Loss: 0.59, Validation Loss: 0.58, accuracy = 0.72\n",
      "Epoch: 4, Training Loss: 0.52, Validation Loss: 0.51, accuracy = 0.75\n",
      "Epoch: 5, Training Loss: 0.50, Validation Loss: 0.59, accuracy = 0.58\n",
      "Epoch: 6, Training Loss: 0.49, Validation Loss: 0.64, accuracy = 0.56\n",
      "Epoch: 7, Training Loss: 0.45, Validation Loss: 0.37, accuracy = 0.82\n",
      "Epoch: 8, Training Loss: 0.41, Validation Loss: 0.31, accuracy = 0.85\n",
      "Epoch: 9, Training Loss: 0.43, Validation Loss: 0.55, accuracy = 0.59\n",
      "Epoch: 10, Training Loss: 0.36, Validation Loss: 0.46, accuracy = 0.72\n",
      "Epoch: 11, Training Loss: 0.35, Validation Loss: 0.33, accuracy = 0.83\n",
      "Epoch: 12, Training Loss: 0.34, Validation Loss: 0.30, accuracy = 0.87\n",
      "Epoch: 13, Training Loss: 0.31, Validation Loss: 0.48, accuracy = 0.80\n",
      "Epoch: 14, Training Loss: 0.30, Validation Loss: 0.63, accuracy = 0.67\n",
      "Epoch: 15, Training Loss: 0.39, Validation Loss: 0.42, accuracy = 0.79\n",
      "Epoch: 16, Training Loss: 0.29, Validation Loss: 0.85, accuracy = 0.61\n",
      "Epoch: 17, Training Loss: 0.36, Validation Loss: 0.29, accuracy = 0.86\n",
      "Epoch: 18, Training Loss: 0.25, Validation Loss: 0.78, accuracy = 0.63\n",
      "Epoch: 19, Training Loss: 0.25, Validation Loss: 0.41, accuracy = 0.82\n"
     ]
    }
   ],
   "source": [
    "# Run training with the function's default epoch count and device.\n",
    "train(\n",
    "    model=alexnet,\n",
    "    optimizer=optimizer,\n",
    "    loss_fn=loss_fn,\n",
    "    train_loader=train_data_loader,\n",
    "    val_loader=val_data_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model accuracy on test data: 82.1%\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the trained model on the held-out test split.\n",
    "\n",
    "# Put the model in evaluation mode explicitly so dropout is disabled here no\n",
    "# matter which mode an earlier cell left the model in.\n",
    "alexnet.eval()\n",
    "# Feed batches on whatever device the model's parameters currently live on.\n",
    "model_device = next(alexnet.parameters()).device\n",
    "\n",
    "corrects = 0\n",
    "total = 0\n",
    "# Inference only: skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    for inputs, targets in test_data_loader:\n",
    "        inputs = inputs.to(model_device)\n",
    "        targets = targets.to(model_device)\n",
    "        prediction = alexnet(inputs)\n",
    "        # Softmax is monotonic, so the arg-max of the raw scores is already\n",
    "        # the predicted class.\n",
    "        corrects += (prediction.argmax(dim=1) == targets).sum().item()\n",
    "        total += targets.size(0)\n",
    "\n",
    "# An empty test loader reports NaN instead of raising ZeroDivisionError.\n",
    "accuracy = corrects / total * 100 if total else float('nan')\n",
    "print(\"Model accuracy on test data: {}%\".format(round(accuracy, 2)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
